# agnews_zne_utils.py
#
# Zero-noise extrapolation (ZNE) utilities for AG-News distillation (adapted from zne.py)

import numpy as np
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory


def build_noise_model(base_eps: float):
    """Build an Aer noise model with uniform depolarizing noise.

    The same depolarizing parameter ``base_eps`` is attached to every listed
    single-qubit gate (as a 1-qubit channel) and to ``cx`` (as a 2-qubit
    channel), on all qubits.

    Args:
        base_eps: Depolarizing parameter passed to ``depolarizing_error``.

    Returns:
        The populated ``NoiseModel``.

    NOTE(review): only the gate names listed below receive noise; any other
    gate the circuit is transpiled to (e.g. 'sx', 'h', 'id') stays noiseless.
    """
    noisy_gates = (
        (1, ['x', 'rx', 'ry', 'rz', 'u1', 'u2', 'u3']),
        (2, ['cx']),
    )
    model = NoiseModel()
    for arity, gate_names in noisy_gates:
        channel = depolarizing_error(base_eps, arity)
        model.add_all_qubit_quantum_error(channel, gate_names)
    return model


def _zero_state_probability(counts) -> float:
    """Return the fraction of shots in which every measured bit read 0.

    Qiskit counts keys are bitstrings over the measured *classical* bits, with
    a space separating classical registers. Matching the all-zero outcome
    after stripping those separators avoids silently returning 0 when the
    number of measured bits differs from the number of qubits, or when the
    circuit has more than one classical register.
    """
    total = sum(counts.values())
    if total == 0:
        return 0.0
    zero_hits = sum(
        hits
        for bitstring, hits in counts.items()
        if set(bitstring.replace(' ', '')) <= {'0'}
    )
    return zero_hits / total


def make_executor(base_eps: float, seed: int = 123, shots: int = 1024):
    """Create a mitiq-style executor that runs circuits on a noisy Aer simulator.

    Args:
        base_eps: Depolarizing parameter for ``build_noise_model``.
        seed: Seed used for both the simulator and the transpiler, so repeated
            runs of the same circuit are reproducible.
        shots: Number of shots per circuit execution.

    Returns:
        A callable ``circuit -> float`` giving the fraction of shots whose
        measured bitstring is all zeros.

    NOTE(review): presumably each circuit must contain measurements, otherwise
    ``get_counts()`` has nothing to return — confirm with callers.
    """
    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=seed)
    qinst = QuantumInstance(backend=backend, shots=shots,
                            seed_transpiler=seed)

    def _executor(circuit):
        counts = qinst.execute(circuit).get_counts()
        return _zero_state_probability(counts)

    return _executor


def zne_expectation_zero(vqc, x, base_eps: float, scale_factors=(1,3), seed: int = 123):
    """Return ZNE-mitigated values for every circuit the VQC builds for ``x``.

    Each circuit from ``vqc._neural_network.construct_circuit(x)`` is run at
    the given noise scale factors (via global unitary folding) through
    ``make_executor``. The results are then Richardson-extrapolated to the
    zero-noise limit. The extrapolated value is clipped into (0, 1) and mapped
    to ``(1 - value) / 2``.

    Args:
        vqc: Model exposing ``_neural_network.construct_circuit(x)``, which
            must return an iterable of circuits (private API).
        x: Input forwarded to ``construct_circuit``.
        base_eps: Depolarizing parameter for the simulator noise model.
        scale_factors: Noise scale factors for folding; at least two values.
        seed: Simulator/transpiler seed for reproducibility.

    Returns:
        ``np.ndarray`` of floats, one entry per constructed circuit.
    """
    executor = make_executor(base_eps, seed=seed)
    # RichardsonFactory has no ``order`` argument: its order is fixed at
    # len(scale_factors) - 1, so the default (1, 3) is already a linear fit.
    factory = RichardsonFactory(scale_factors=list(scale_factors))

    energies = []
    for circ in vqc._neural_network.construct_circuit(x):
        ez = zne.execute_with_zne(
            circ, executor, scale_noise=folding.fold_global, factory=factory
        )
        # Extrapolation can overshoot the [0, 1] probability range.
        ez = float(np.clip(ez, 1e-12, 1 - 1e-12))
        # NOTE(review): ``ez`` is an all-zeros probability here, yet
        # (1 - v) / 2 is the mapping from a <Z> expectation in [-1, 1] to P(1).
        # Confirm which quantity is intended.
        energies.append((1.0 - ez) / 2)

    return np.array(energies, dtype=float)
